/* Copyright 2021 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_PROFILER_GPU_MOCK_CUPTI_H_
#define XLA_BACKENDS_PROFILER_GPU_MOCK_CUPTI_H_

#include <stddef.h>
#include <stdint.h>

#include <cstdint>

#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h"
#include "third_party/gpus/cuda/extras/CUPTI/include/cupti_profiler_target.h"
#include "third_party/gpus/cuda/extras/CUPTI/include/cupti_target.h"
#include "xla/backends/profiler/gpu/cupti_interface.h"
#include "tsl/platform/test.h"

#if CUPTI_PM_SAMPLING_SUPPORTED  // Defined in cupti_interface.h
#include "third_party/gpus/cuda/extras/CUPTI/include/cupti_pmsampling.h"
#include "third_party/gpus/cuda/extras/CUPTI/include/cupti_profiler_host.h"
#endif

namespace xla {
namespace profiler {

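// gMock implementation of CuptiInterface for use in tests. Originally
// generated by gmock_gen.py. Every member is declared `override`, so this
// class must be kept in sync with CuptiInterface whenever it changes.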
class MockCupti : public xla::profiler::CuptiInterface {
 public:
  // NOTE: declaration order is intentionally kept stable. Each MOCK_METHOD
  // introduces a mocker data member, and member order decides the order in
  // which unsatisfied expectations are reported on destruction.

  // ---- Activity API ------------------------------------------------------
  MOCK_METHOD(CUptiResult, ActivityDisable, (CUpti_ActivityKind activity_kind),
              (override));
  MOCK_METHOD(CUptiResult, ActivityEnable, (CUpti_ActivityKind activity_kind),
              (override));
  MOCK_METHOD(CUptiResult, ActivityFlushAll, (uint32_t flush_flag),
              (override));
  MOCK_METHOD(CUptiResult, ActivityGetNextRecord,
              (uint8_t* buffer, size_t valid_bytes, CUpti_Activity** record),
              (override));
  MOCK_METHOD(CUptiResult, ActivityGetNumDroppedRecords,
              (CUcontext ctx, uint32_t stream_id, size_t* num_dropped),
              (override));
  MOCK_METHOD(CUptiResult, ActivityConfigureUnifiedMemoryCounter,
              (CUpti_ActivityUnifiedMemoryCounterConfig* um_config,
               uint32_t num_configs),
              (override));
  MOCK_METHOD(CUptiResult, ActivityRegisterCallbacks,
              (CUpti_BuffersCallbackRequestFunc on_buffer_requested,
               CUpti_BuffersCallbackCompleteFunc on_buffer_completed),
              (override));
  MOCK_METHOD(CUptiResult, ActivityUsePerThreadBuffer, (), (override));
  MOCK_METHOD(CUptiResult, SetActivityFlushPeriod, (uint32_t flush_period_ms),
              (override));

  // ---- Device, timestamp and lifecycle -------------------------------------
  MOCK_METHOD(CUptiResult, GetDeviceId, (CUcontext ctx, uint32_t* device_id),
              (override));
  MOCK_METHOD(CUptiResult, GetTimestamp, (uint64_t* timestamp_out),
              (override));
  MOCK_METHOD(CUptiResult, Finalize, (), (override));

  // ---- Callback subscription ---------------------------------------------
  MOCK_METHOD(CUptiResult, EnableCallback,
              (uint32_t enabled, CUpti_SubscriberHandle handle,
               CUpti_CallbackDomain callback_domain,
               CUpti_CallbackId callback_id),
              (override));
  MOCK_METHOD(CUptiResult, EnableDomain,
              (uint32_t enabled, CUpti_SubscriberHandle handle,
               CUpti_CallbackDomain callback_domain),
              (override));
  MOCK_METHOD(CUptiResult, Subscribe,
              (CUpti_SubscriberHandle* handle_out, CUpti_CallbackFunc callback,
               void* user_data),
              (override));
  MOCK_METHOD(CUptiResult, Unsubscribe, (CUpti_SubscriberHandle handle),
              (override));
  MOCK_METHOD(CUptiResult, GetResultString,
              (CUptiResult status, const char** status_str), (override));

  // ---- Identifier lookups and tracing options -----------------------------
  MOCK_METHOD(CUptiResult, GetContextId, (CUcontext ctx, uint32_t* ctx_id),
              (override));
  MOCK_METHOD(CUptiResult, GetStreamIdEx,
              (CUcontext ctx, CUstream cu_stream, uint8_t is_per_thread_stream,
               uint32_t* stream_id),
              (override));
  MOCK_METHOD(CUptiResult, GetGraphId, (CUgraph cu_graph, uint32_t* graph_id),
              (override));
  MOCK_METHOD(CUptiResult, GetGraphNodeId,
              (CUgraphNode graph_node, uint64_t* node_id), (override));
  MOCK_METHOD(CUptiResult, SetThreadIdType,
              (CUpti_ActivityThreadIdType thread_id_type), (override));
  MOCK_METHOD(CUptiResult, ActivityEnableHWTrace, (bool hw_trace_on),
              (override));
  MOCK_METHOD(CUptiResult, GetGraphExecId,
              (CUgraphExec graph_exec, uint32_t* graph_exec_id), (override));

  // The remaining CUptiResult methods each take a single CUPTI `*_Params`
  // struct pointer. The struct type already names the argument, so the
  // parameter is left unnamed.
  //
  // NOTE(review): the Profiler Host and PM Sampling parameter types are
  // declared in cupti_profiler_host.h / cupti_pmsampling.h, which this header
  // includes only when CUPTI_PM_SAMPLING_SUPPORTED is set, yet these mocks are
  // unconditional. Presumably cupti_interface.h makes the types available in
  // all configurations -- confirm.

  // ---- Profiler Host API ---------------------------------------------------
  MOCK_METHOD(CUptiResult, ProfilerHostInitialize,
              (CUpti_Profiler_Host_Initialize_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerHostDeinitialize,
              (CUpti_Profiler_Host_Deinitialize_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerHostGetSupportedChips,
              (CUpti_Profiler_Host_GetSupportedChips_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerHostGetBaseMetrics,
              (CUpti_Profiler_Host_GetBaseMetrics_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerHostGetSubMetrics,
              (CUpti_Profiler_Host_GetSubMetrics_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerHostGetMetricProperties,
              (CUpti_Profiler_Host_GetMetricProperties_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerHostGetRangeName,
              (CUpti_Profiler_Host_GetRangeName_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerHostEvaluateToGpuValues,
              (CUpti_Profiler_Host_EvaluateToGpuValues_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerHostConfigAddMetrics,
              (CUpti_Profiler_Host_ConfigAddMetrics_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerHostGetConfigImageSize,
              (CUpti_Profiler_Host_GetConfigImageSize_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerHostGetConfigImage,
              (CUpti_Profiler_Host_GetConfigImage_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerHostGetNumOfPasses,
              (CUpti_Profiler_Host_GetNumOfPasses_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerHostGetMaxNumHardwareMetricsPerPass,
              (CUpti_Profiler_Host_GetMaxNumHardwareMetricsPerPass_Params*),
              (override));

  // ---- Profiler Target API -------------------------------------------------
  MOCK_METHOD(CUptiResult, ProfilerInitialize,
              (CUpti_Profiler_Initialize_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerDeInitialize,
              (CUpti_Profiler_DeInitialize_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerCounterDataImageCalculateSize,
              (CUpti_Profiler_CounterDataImage_CalculateSize_Params*),
              (override));
  MOCK_METHOD(CUptiResult, ProfilerCounterDataImageInitialize,
              (CUpti_Profiler_CounterDataImage_Initialize_Params*),
              (override));
  MOCK_METHOD(
      CUptiResult, ProfilerCounterDataImageCalculateScratchBufferSize,
      (CUpti_Profiler_CounterDataImage_CalculateScratchBufferSize_Params*),
      (override));
  MOCK_METHOD(CUptiResult, ProfilerCounterDataImageInitializeScratchBuffer,
              (CUpti_Profiler_CounterDataImage_InitializeScratchBuffer_Params*),
              (override));
  MOCK_METHOD(CUptiResult, ProfilerBeginSession,
              (CUpti_Profiler_BeginSession_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerEndSession,
              (CUpti_Profiler_EndSession_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerSetConfig,
              (CUpti_Profiler_SetConfig_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerUnsetConfig,
              (CUpti_Profiler_UnsetConfig_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerBeginPass,
              (CUpti_Profiler_BeginPass_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerEndPass, (CUpti_Profiler_EndPass_Params*),
              (override));
  MOCK_METHOD(CUptiResult, ProfilerEnableProfiling,
              (CUpti_Profiler_EnableProfiling_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerDisableProfiling,
              (CUpti_Profiler_DisableProfiling_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerIsPassCollected,
              (CUpti_Profiler_IsPassCollected_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerFlushCounterData,
              (CUpti_Profiler_FlushCounterData_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerPushRange,
              (CUpti_Profiler_PushRange_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerPopRange, (CUpti_Profiler_PopRange_Params*),
              (override));
  MOCK_METHOD(CUptiResult, ProfilerGetCounterAvailability,
              (CUpti_Profiler_GetCounterAvailability_Params*), (override));
  MOCK_METHOD(CUptiResult, ProfilerDeviceSupported,
              (CUpti_Profiler_DeviceSupported_Params*), (override));

  // ---- PM Sampling API -----------------------------------------------------
  MOCK_METHOD(CUptiResult, PmSamplingSetConfig,
              (CUpti_PmSampling_SetConfig_Params*), (override));
  MOCK_METHOD(CUptiResult, PmSamplingEnable, (CUpti_PmSampling_Enable_Params*),
              (override));
  MOCK_METHOD(CUptiResult, PmSamplingDisable,
              (CUpti_PmSampling_Disable_Params*), (override));
  MOCK_METHOD(CUptiResult, PmSamplingStart, (CUpti_PmSampling_Start_Params*),
              (override));
  MOCK_METHOD(CUptiResult, PmSamplingStop, (CUpti_PmSampling_Stop_Params*),
              (override));
  MOCK_METHOD(CUptiResult, PmSamplingDecodeData,
              (CUpti_PmSampling_DecodeData_Params*), (override));
  MOCK_METHOD(CUptiResult, PmSamplingGetCounterAvailability,
              (CUpti_PmSampling_GetCounterAvailability_Params*), (override));
  MOCK_METHOD(CUptiResult, PmSamplingGetCounterDataSize,
              (CUpti_PmSampling_GetCounterDataSize_Params*), (override));
  MOCK_METHOD(CUptiResult, PmSamplingCounterDataImageInitialize,
              (CUpti_PmSampling_CounterDataImage_Initialize_Params*),
              (override));
  MOCK_METHOD(CUptiResult, PmSamplingGetCounterDataInfo,
              (CUpti_PmSampling_GetCounterDataInfo_Params*), (override));
  MOCK_METHOD(CUptiResult, PmSamplingCounterDataGetSampleInfo,
              (CUpti_PmSampling_CounterData_GetSampleInfo_Params*),
              (override));

  // ---- Device query ----------------------------------------------------------
  MOCK_METHOD(CUptiResult, DeviceGetChipName,
              (CUpti_Device_GetChipName_Params*), (override));

  // ---- Members of CuptiInterface that do not return CUptiResult ------------
  MOCK_METHOD(void, CleanUp, (), (override));
  MOCK_METHOD(bool, Disabled, (), (const, override));
};

}  // namespace profiler
}  // namespace xla

#endif  // XLA_BACKENDS_PROFILER_GPU_MOCK_CUPTI_H_
